#-*-Python-*-

import gin

import social_rl.multiagent_tfagents.multiagent_train_eval
import tf_agents.agents.ppo.ppo_agent

# Run function
# Binds the `train_eval` parameter of the `xm_train` configurable to a
# reference to multiagent_train_eval.train_eval. The bare `@name` form passes
# the configurable itself; it is not called at parse time (that would be
# `@name()`).
# NOTE(review): `xm_train` is not imported in this file; presumably it is
# registered by the launching binary -- confirm.
xm_train.train_eval = @multiagent_train_eval.train_eval

# Special fixed agent
# A one-element list of agent ids passed to train_eval as `fixed_agent_ids`.
# NOTE(review): presumably the listed agents are treated as fixed (not
# updated during training) and the ids are indices into the environment's
# agent list; the exact semantics live in train_eval -- confirm there.
multiagent_train_eval.train_eval.fixed_agent_ids = [1]

# Environment
# Environment identifier string passed to train_eval as `env_name`.
# NOTE(review): presumably an environment id registered by the social_rl
# MultiGrid package; the registration is not visible from this file.
multiagent_train_eval.train_eval.env_name = 'MultiGrid-Cluttered-v0'

# Training params
# All bindings below set keyword arguments of multiagent_train_eval.train_eval.
# Units and exact semantics are defined by that function, not by this file.

# Total environment-step budget for the run (25 million).
multiagent_train_eval.train_eval.num_environment_steps=25000000
# NOTE(review): 4001 is an unusual, non-round capacity; presumably chosen
# relative to the maximum number of steps stored per environment -- confirm
# before changing.
multiagent_train_eval.train_eval.replay_buffer_capacity = 4001
# NOTE(review): equal to num_parallel_environments below (both 30);
# presumably the two are meant to be kept in sync -- confirm.
multiagent_train_eval.train_eval.collect_episodes_per_iteration = 30
# Checkpointing cadence; train and policy checkpoints use the same interval.
multiagent_train_eval.train_eval.train_checkpoint_interval = 500
multiagent_train_eval.train_eval.policy_checkpoint_interval = 500
# Logging / summary cadence.
# NOTE(review): whether these intervals count train steps or iterations is
# not visible here -- confirm in train_eval.
multiagent_train_eval.train_eval.log_interval = 15
multiagent_train_eval.train_eval.summary_interval = 50
multiagent_train_eval.train_eval.num_parallel_environments = 30
multiagent_train_eval.train_eval.num_eval_episodes = 10
multiagent_train_eval.train_eval.debug = False

# Architecture params
# Network-shape arguments forwarded to multiagent_train_eval.train_eval.
# The *_fc_layers and lstm_size values are tuples (note the trailing comma in
# the single-element `(256,)`); the conv_* and direction_fc values are scalars.
# NOTE(review): presumably each tuple entry is the unit count of one layer,
# conv_kernel is a square kernel size, and direction_fc is the width of a
# dense layer applied to the agent's direction input -- confirm against the
# network construction code.
multiagent_train_eval.train_eval.actor_fc_layers = (32, 32)
multiagent_train_eval.train_eval.value_fc_layers=(32, 32)
multiagent_train_eval.train_eval.lstm_size=(256,)
multiagent_train_eval.train_eval.conv_filters=16
multiagent_train_eval.train_eval.conv_kernel=3
multiagent_train_eval.train_eval.direction_fc=5

# Agent params
# discount_factor is bound on the PPOAgent configurable itself rather than on
# train_eval, so it is supplied to any PPOAgent constructed without an
# explicit discount_factor argument. An explicitly passed argument at the
# call site takes precedence over this binding.
ppo_agent.PPOAgent.discount_factor = 0.995
# `.00` is the float 0.0.
# NOTE(review): presumably this disables the entropy bonus -- confirm how
# train_eval forwards it to the agent.
multiagent_train_eval.train_eval.entropy_regularization = .00
